#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2025/7/23
# @USER    : Shengji He
# @File    : model.py
# @Software: PyCharm
# @Version  : Python-
# @TASK:
import os
import glob
from typing import Union
import time
import numpy as np
import nibabel as nib
from nibabel.nifti1 import Nifti1Image

from SegFlow.map_to_binary import class_map, class_map_5_parts, map_taskid_to_partname_ct
from SegFlow.resampling import change_spacing
from SegFlow.postprocessing import keep_largest_blob_multilabel, remove_small_blobs_multilabel, remove_auxiliary_labels
from SegFlow.cropping import crop_to_mask, undo_crop
from SegFlow.nifti_ext_header import add_label_map_to_nifti

from SegFlow.utils import nostdout


# from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor


def combine_masks(mask_dir, class_type):
    """
    Combine several single-class masks into one binary mask.

    mask_dir: directory (pathlib.Path) of totalsegmentator masks
    class_type: ribs | lung | lung_left | lung_right | pelvis | body

    returns: nibabel image (uint8; 1 wherever any selected class is present)

    raises:
        ValueError: if class_type is unknown or a required mask file is missing.
    """
    rib_classes = [f"rib_{side}_{idx}" for side in ("left", "right") for idx in range(1, 13)]
    class_groups = {
        "ribs": rib_classes,
        "lung": ["lung_upper_lobe_left", "lung_lower_lobe_left", "lung_upper_lobe_right",
                 "lung_middle_lobe_right", "lung_lower_lobe_right"],
        "lung_left": ["lung_upper_lobe_left", "lung_lower_lobe_left"],
        "lung_right": ["lung_upper_lobe_right", "lung_middle_lobe_right", "lung_lower_lobe_right"],
        "pelvis": ["femur_left", "femur_right", "hip_left", "hip_right"],
        "body": ["body_trunc", "body_extremities"],
    }
    if class_type not in class_groups:
        # Previously an unknown class_type fell through and crashed later with
        # a NameError on `masks`; fail explicitly instead.
        raise ValueError(f"Unknown class_type: {class_type}")
    masks = class_groups[class_type]

    # Fail early if any mask file is missing (before loading anything).
    for mask in masks:
        if not (mask_dir / f"{mask}.nii.gz").exists():
            raise ValueError(f"Could not find {mask_dir / mask}.nii.gz. Did you run TotalSegmentator successfully?")

    # Load the first mask once as the geometric reference (shape + affine);
    # the original reloaded it on every loop iteration.
    ref_img = nib.load(mask_dir / f"{masks[0]}.nii.gz")
    combined = np.zeros(ref_img.shape, dtype=np.uint8)
    for mask in masks:
        img = nib.load(mask_dir / f"{mask}.nii.gz").get_fdata()
        combined[img > 0.5] = 1

    return nib.Nifti1Image(combined, ref_img.affine)


def check_if_shape_and_affine_identical(img_1, img_2):
    """
    Print a warning if two images differ in affine (beyond 1e-5) or in shape.

    img_1, img_2: nibabel images (anything exposing .affine and .shape)
    """
    affine_diff = np.abs(img_1.affine - img_2.affine)
    if affine_diff.max() > 1e-5:
        print("Affine in:")
        print(img_1.affine)
        print("Affine out:")
        print(img_2.affine)
        print("Diff:")
        print(affine_diff)
        print("WARNING: Output affine not equal to input affine. This should not happen.")

    if img_1.shape != img_2.shape:
        print("Shape in:")
        print(img_1.shape)
        print("Shape out:")
        print(img_2.shape)
        print("WARNING: Output shape not equal to input shape. This should not happen.")


def reorder_multilabel_like_v1(data, label_map_v2, label_map_v1):
    """
    Remap label ids in a multilabel image from v2 numbering to v1 numbering.

    Classes present in v1 but absent from v2 (e.g. heart chambers) are simply
    left empty in the result.
    """
    name_to_v2_id = {name: lid for lid, name in label_map_v2.items()}
    out = np.zeros(data.shape, dtype=np.uint8)
    for v1_id, name in label_map_v1.items():
        v2_id = name_to_v2_id.get(name)
        if v2_id is not None:
            out[data == v2_id] = v1_id
    return out


# ---------------------------------------------------------
def find_candidate_datasets(dataset_id: int):
    """
    Locate weight folders under ../weights whose name starts with 'DatasetXXX'.

    dataset_id: numeric dataset id (zero-padded to 3 digits in the prefix)

    returns: tuple of unique matching directory paths (order unspecified)
    """
    prefix = "Dataset%03.0d" % dataset_id
    pattern = os.path.join('../weights', f'{prefix}*')
    matching_dirs = [path for path in glob.glob(pattern) if os.path.isdir(path)]
    return tuple(set(matching_dirs))


def convert_id_to_dataset_name(dataset_id: int):
    """
    Resolve a numeric dataset id to the unique dataset folder name under ../weights.

    dataset_id: numeric dataset id

    returns: the single matching dataset folder path

    raises:
        RuntimeError: if zero or more than one matching dataset is found.
    """
    unique_candidates = find_candidate_datasets(dataset_id)
    if len(unique_candidates) > 1:
        # The original message interpolated undefined module globals
        # (nnUNet_raw / nnUNet_preprocessed / nnUNet_results), so this branch
        # raised NameError instead of the intended RuntimeError. Report the
        # actual conflicting candidates instead.
        raise RuntimeError("More than one dataset name found for dataset id %d. Please correct that. "
                           "(Candidates found: %s)" % (dataset_id, ", ".join(unique_candidates)))
    if len(unique_candidates) == 0:
        # f"{None}" renders as 'None', so the explicit conditionals in the
        # original were redundant.
        raise RuntimeError(f"Could not find a dataset with the ID {dataset_id}. Make sure the requested dataset ID "
                           f"exists and that nnU-Net knows where raw and preprocessed data are located "
                           f"(see Documentation - Installation). Here are your currently defined folders:\n"
                           f"nnUNet_preprocessed={os.environ.get('nnUNet_preprocessed')}\n"
                           f"nnUNet_results={os.environ.get('nnUNet_results')}\n"
                           f"nnUNet_raw={os.environ.get('nnUNet_raw')}\n"
                           f"If something is not right, adapt your environment variables.")
    return unique_candidates[0]


def maybe_convert_to_dataset_name(dataset_name_or_id: Union[int, str]) -> str:
    """
    Normalize a dataset identifier to a dataset folder name.

    Strings already starting with 'Dataset' are returned as-is; other strings
    are parsed as integer ids and resolved via convert_id_to_dataset_name().

    raises:
        ValueError: if a string input is neither a 'Dataset...' name nor an integer.
    """
    if isinstance(dataset_name_or_id, str) and dataset_name_or_id.startswith("Dataset"):
        return dataset_name_or_id
    if isinstance(dataset_name_or_id, str):
        try:
            dataset_name_or_id = int(dataset_name_or_id)
        except ValueError:
            # Fixed typo in the original message: "tast name" -> "task name".
            raise ValueError("dataset_name_or_id was a string and did not start with 'Dataset' so we tried to "
                             "convert it to a dataset ID (int). That failed, however. Please give an integer number "
                             "('1', '2', etc) or a correct task name. Your input: %s" % dataset_name_or_id)
    return convert_id_to_dataset_name(dataset_name_or_id)


def convert_trainer_plans_config_to_identifier(trainer_name, plans_identifier, configuration):
    """Join trainer, plans and configuration into the nnU-Net model folder identifier."""
    return "__".join((trainer_name, plans_identifier, configuration))


def get_output_folder(dataset_name_or_id: Union[str, int], trainer_name: str = 'nnUNetTrainer',
                      plans_identifier: str = 'nnUNetPlans', configuration: str = '3d_fullres',
                      fold: Union[str, int] = None) -> str:
    """
    Build the path to a trained model folder under ../weights.

    fold: optional fold number/name; when given, a 'fold_<fold>' suffix is appended.

    returns: path string of the model folder
    """
    dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)
    identifier = convert_trainer_plans_config_to_identifier(trainer_name, plans_identifier, configuration)
    folder = os.path.join('../weights', dataset_name, identifier)
    if fold is not None:
        folder = os.path.join(folder, f'fold_{fold}')
    return folder


"""
alignment
"""


def as_closest_canonical(img_in):
    """Reorient a nibabel image to the closest canonical (RAS) orientation."""
    canonical_img = nib.as_closest_canonical(img_in)
    return canonical_img


def undo_canonical(img_can, img_orig):
    """
    Invert nib.as_closest_canonical(): reorient img_can back into the
    orientation of img_orig.

    img_can: the image currently in canonical (RAS) orientation
    img_orig: the original image from before the canonical transform

    returns: image in the original orientation

    https://github.com/nipy/nibabel/issues/1063
    """
    orig_ornt = nib.orientations.io_orientation(img_orig.affine)
    canonical_ornt = nib.orientations.axcodes2ornt("RAS")
    # Transform mapping canonical (RAS) back onto the original orientation.
    # (orig_ornt alone would be the forward, to-canonical transform.)
    back_transform = nib.orientations.ornt_transform(canonical_ornt, orig_ornt)
    return img_can.as_reoriented(back_transform)


"""
nnunet
"""


# def supports_keyword_argument(func, keyword: str):
#     """
#     Check if a function supports a specific keyword argument.
#
#     Returns:
#     - True if the function supports the specified keyword argument.
#     - False otherwise.
#     """
#     signature = inspect.signature(func)
#     parameters = signature.parameters
#     return keyword in parameters


def nnUNetv2_predict(dir_in, dir_out, task_id, model="3d_fullres", folds=None,
                     trainer="nnUNetTrainer", tta=False,
                     num_threads_preprocessing=3, num_threads_nifti_save=2,
                     plans="nnUNetPlans", device="cuda", quiet=False, step_size=0.5):
    """
    Identical to bash function nnUNetv2_predict

    folds:  folds to use for prediction. Default is None which means that folds will be detected
            automatically in the model output folder.
            for all folds: None
            for only fold 0: [0]

    NOTE(review): this function references `supports_keyword_argument` (its
    definition is commented out above) and `nnUNetPredictor` (its import is
    commented out at the top of the file). As written, calling this function
    raises NameError — confirm those should be re-enabled.
    """
    # Resolve the trained-model folder under ../weights from the task id.
    model_folder = get_output_folder(task_id, trainer, plans, model)
    # model_folder = r'../weights/Dataset299_body_1559subj/nnUNetTrainer__nnUNetPlans__3d_fullres'

    disable_tta = not tta
    verbose = False
    save_probabilities = False
    continue_prediction = False
    chk = "checkpoint_final.pth"
    npp = num_threads_preprocessing
    nps = num_threads_nifti_save
    prev_stage_predictions = None
    # Single-part prediction: no sharding of the input list across workers.
    num_parts = 1
    part_id = 0
    allow_tqdm = not quiet

    # nnUNet 2.2.1
    if supports_keyword_argument(nnUNetPredictor, "perform_everything_on_gpu"):
        predictor = nnUNetPredictor(
            tile_step_size=step_size,
            use_gaussian=True,
            use_mirroring=not disable_tta,
            perform_everything_on_gpu=True,  # for nnunetv2<=2.2.1
            device=device,
            verbose=verbose,
            verbose_preprocessing=verbose,
            allow_tqdm=allow_tqdm
        )
    # nnUNet >= 2.2.2
    # NOTE(review): with this else-branch commented out, `predictor` is
    # undefined whenever the keyword check above returns False.
    # else:
    #     predictor = nnUNetPredictor(
    #         tile_step_size=step_size,
    #         use_gaussian=True,
    #         use_mirroring=not disable_tta,
    #         perform_everything_on_device=True,  # for nnunetv2>=2.2.2
    #         device=device,
    #         verbose=verbose,
    #         verbose_preprocessing=verbose,
    #         allow_tqdm=allow_tqdm
    #     )
    predictor.initialize_from_trained_model_folder(
        model_folder,
        use_folds=folds,
        checkpoint_name=chk,
    )
    predictor.predict_from_files(dir_in, dir_out,
                                 save_probabilities=save_probabilities, overwrite=not continue_prediction,
                                 num_processes_preprocessing=npp, num_processes_segmentation_export=nps,
                                 folder_with_segs_from_prev_stage=prev_stage_predictions,
                                 num_parts=num_parts, part_id=part_id)

    # # Use numpy as input. TODO: In entire pipeline do not save to disk
    # input_image = nib.load(Path(dir_in) / "s01_0000.nii.gz")
    # input_data = np.asanyarray(input_image.dataobj).transpose(2, 1, 0)[None,...].astype(np.float32)
    # spacing = input_image.header.get_zooms()
    # affine = input_image.affine
    # # Do i have to transpose spacing? does not matter because anyways isotropic at this point.
    # spacing = (spacing[2], spacing[1], spacing[0])
    # props = {"spacing": spacing}
    # # from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
    # # input_data, props = SimpleITKIO().read_images([os.path.join(dir_in, "s01_0000.nii.gz")])
    # seg = predictor.predict_single_npy_array(input_data, props,
    #                                          prev_stage_predictions, None,
    #                                          save_probabilities)
    # seg = seg.transpose(2, 1, 0)
    # nib.save(nib.Nifti1Image(seg.astype(np.uint8), affine), Path(dir_out) / "s01.nii.gz")


def nnUNet_predict_image_body(file_in: Union[str, Nifti1Image], file_out, task_id, model="3d_fullres", folds=None,
                              trainer="nnUNetTrainerV2", tta=False, multilabel_image=True, resample=None,
                              task_name="total",
                              save_binary=False, nr_threads_resampling=1, nr_threads_saving=6, quiet=False,
                              verbose=False,
                              skip_saving=False, device="cuda", no_derived_masks=False, tmp_dir=''):
    """
    Run the body-segmentation prediction pipeline on a single image.

    Pipeline: load -> canonical reorientation -> optional resampling ->
    nnU-Net prediction (via files in tmp_dir) -> body-specific blob
    postprocessing -> resample back -> undo reorientation -> save.

    file_in: path to a NIfTI file, or an already-loaded Nifti1Image
    file_out: output path (multilabel file) or output directory (per-class masks);
              None to skip saving
    resample: None or float (target spacing for all dimensions) or list of floats
    save_binary: collapse all labels to a single foreground mask before saving
    tmp_dir: working directory the nnU-Net prediction reads from / writes to

    NOTE(review): despite the default task_name="total", the postprocessing
    below unconditionally references the "body_trunc"/"body_extremities"
    classes — this function appears to assume a body-style class_map; confirm.
    """

    if type(resample) is float:
        resample = [resample, resample, resample]

    if isinstance(file_in, Nifti1Image):
        img_in_orig = file_in
    else:
        img_in_orig = nib.load(file_in)

    if len(img_in_orig.shape) == 2:
        raise ValueError("TotalSegmentator does not work for 2D images. Use a 3D image.")
    if len(img_in_orig.shape) > 3:
        # 4D input (e.g. a time series): keep only the first volume.
        print(f"WARNING: Input image has {len(img_in_orig.shape)} dimensions. Only using first three dimensions.")
        img_in_orig = nib.Nifti1Image(img_in_orig.get_fdata()[:, :, :, 0], img_in_orig.affine)

    img_dtype = img_in_orig.get_data_dtype()
    if img_dtype.fields is not None:
        raise TypeError(f"Invalid dtype {img_dtype}. Expected a simple dtype, not a structured one.")

    # takes ~0.9s for medium image
    img_in = nib.Nifti1Image(img_in_orig.get_fdata(), img_in_orig.affine)  # copy img_in_orig

    # Reorient to canonical (RAS); undone again after prediction.
    img_in = as_closest_canonical(img_in)

    if resample is not None:
        if not quiet:
            print("Resampling...")
        st = time.time()
        # Remember the pre-resampling shape so the prediction can be resampled back.
        img_in_shape = img_in.shape
        img_in_zooms = img_in.header.get_zooms()
        # NOTE(review): img_in_zooms is unused in this function.
        img_in_rsp = change_spacing(img_in, resample,
                                    order=3, dtype=np.int32,
                                    nr_cpus=nr_threads_resampling)  # 4 cpus instead of 1 makes it a bit slower
        if verbose:
            print(f"  from shape {img_in.shape} to shape {img_in_rsp.shape}")
        if not quiet:
            print(f"  Resampled in {time.time() - st:.2f}s")
    else:
        img_in_rsp = img_in

    # nib.save(img_in_rsp, os.path.join(tmp_dir, "s01_0000.nii.gz"))
    # NOTE(review): the save of the model input above is commented out, yet the
    # prediction below reads from tmp_dir — presumably the caller places
    # s01_0000.nii.gz there beforehand; confirm.

    # todo important: change
    nr_voxels_thr = 512 * 512 * 900
    # nr_voxels_thr = 256*256*900
    img_parts = ["s01"]
    # ss = img_in_rsp.shape
    # Sliding-window step size for nnU-Net tiling.
    step_size = 0.5

    st = time.time()
    if not quiet:
        print("Predicting...")
    # nostdout() silences nnU-Net's console output unless verbose is set.
    with nostdout(verbose):
        # nnUNet_predict(tmp_dir, tmp_dir, task_id, model, folds, trainer, tta,
        #                nr_threads_resampling, nr_threads_saving)
        nnUNetv2_predict(tmp_dir, tmp_dir, task_id, model, folds, trainer, tta,
                         nr_threads_resampling, nr_threads_saving,
                         device=device, quiet=quiet, step_size=step_size)
    if not quiet:
        print(f"  Predicted in {time.time() - st:.2f}s")

    img_pred = nib.load(os.path.join(tmp_dir, "s01.nii.gz"))

    # Postprocessing multilabel (run here on lower resolution)
    # Keep only the largest connected component of the trunk label.
    img_pred_pp = keep_largest_blob_multilabel(img_pred.get_fdata().astype(np.uint8),
                                               class_map[task_name], ["body_trunc"], debug=False, quiet=quiet)
    img_pred = nib.Nifti1Image(img_pred_pp, img_pred.affine)

    # Drop extremity blobs smaller than ~50 cm^3 (threshold converted to voxels).
    vox_vol = np.prod(img_pred.header.get_zooms())
    size_thr_mm3 = 50000 / vox_vol
    img_pred_pp = remove_small_blobs_multilabel(img_pred.get_fdata().astype(np.uint8),
                                                class_map[task_name], ["body_extremities"],
                                                interval=[size_thr_mm3, 1e10], debug=False, quiet=quiet)
    img_pred = nib.Nifti1Image(img_pred_pp, img_pred.affine)

    if resample is not None:
        if not quiet:
            print("Resampling...")
        if verbose:
            print(f"  back to original shape: {img_in_shape}")
        # Use force_affine otherwise output affine sometimes slightly off (which then is even increased
        # by undo_canonical)
        img_pred = change_spacing(img_pred, resample, img_in_shape,
                                  order=0, dtype=np.uint8, nr_cpus=nr_threads_resampling,
                                  force_affine=img_in.affine)

    if verbose:
        print("Undoing canonical...")
    img_pred = undo_canonical(img_pred, img_in_orig)

    # Sanity check: output geometry should match the input exactly (warns only).
    check_if_shape_and_affine_identical(img_in_orig, img_pred)

    img_data = img_pred.get_fdata().astype(np.uint8)
    if save_binary:
        # Collapse all labels into a single foreground mask.
        img_data = (img_data > 0).astype(np.uint8)

    label_map = class_map[task_name]

    # Prepare output nifti
    # Copy header to make output header exactly the same as input. But change dtype otherwise it will be
    # float or int and therefore the masks will need a lot more space.
    # (infos on header: https://nipy.org/nibabel/nifti_images.html)
    new_header = img_in_orig.header.copy()
    new_header.set_data_dtype(np.uint8)
    img_out = nib.Nifti1Image(img_data, img_pred.affine, new_header)
    img_out = add_label_map_to_nifti(img_out, label_map)

    if file_out is not None and skip_saving is False:
        if not quiet:
            print("Saving segmentations...")

        # Select subset of classes if required
        selected_classes = class_map[task_name]

        st = time.time()
        if multilabel_image:
            # Single multilabel file: file_out is treated as a file path.
            os.makedirs(os.path.dirname(file_out), exist_ok=True)
            nib.save(img_out, file_out)
        else:
            # One binary mask per class: file_out is treated as a directory.
            os.makedirs(file_out, exist_ok=True)

            for k, v in selected_classes.items():
                binary_img = img_data == k
                output_path = os.path.join(file_out, f"{v}.nii.gz")
                nib.save(nib.Nifti1Image(binary_img.astype(np.uint8), img_pred.affine, new_header), output_path)

        if task_name == "body" and not multilabel_image and not no_derived_masks:
            if not quiet:
                print("Creating body.nii.gz")
            body_img = combine_masks(file_out, "body")
            # NOTE(review): `file_out / "body.nii.gz"` requires file_out to be a
            # pathlib.Path, while os.path.join above also accepts str — confirm
            # callers pass a Path when multilabel_image is False.
            nib.save(body_img, file_out / "body.nii.gz")
            # if not quiet: print("Creating skin.nii.gz")
            # skin = extract_skin(img_in_orig, nib.load(file_out / "body.nii.gz"))
            # nib.save(skin, file_out / "skin.nii.gz")


def nnUNet_predict_image(file_in: Union[str, Nifti1Image], file_out, task_id, model="3d_fullres", folds=None,
                         trainer="nnUNetTrainerV2", tta=False, multilabel_image=True,
                         resample=None, crop=None, crop_path=None, task_name="total", nora_tag="None", preview=False,
                         save_binary=False, nr_threads_resampling=1, nr_threads_saving=6, force_split=False,
                         crop_addon=[3, 3, 3], roi_subset=None, output_type="nifti",
                         statistics=False, quiet=False, verbose=False, test=0, skip_saving=False,
                         device="cuda", exclude_masks_at_border=True, no_derived_masks=False,
                         v1_order=False, stats_aggregation="mean", tmp_dir=''):
    """
    crop: string or a nibabel image
    resample: None or float (target spacing for all dimensions) or list of floats
    """
    if isinstance(file_in, Nifti1Image):
        img_in_orig = file_in
    else:
        img_in_orig = nib.load(file_in)

    multimodel = type(task_id) is list

    if task_name == "total":
        class_map_parts = class_map_5_parts
        map_taskid_to_partname = map_taskid_to_partname_ct
    # elif task_name == "total_mr":
    #     class_map_parts = class_map_parts_mr
    #     map_taskid_to_partname = map_taskid_to_partname_mr
    # elif task_name == "headneck_muscles":
    #     class_map_parts = class_map_parts_headneck_muscles
    #     map_taskid_to_partname = map_taskid_to_partname_headneck_muscles
    else:
        raise NotImplementedError(task_name)

    if type(resample) is float:
        resample = [resample, resample, resample]

    if len(img_in_orig.shape) == 2:
        raise ValueError("TotalSegmentator does not work for 2D images. Use a 3D image.")
    if len(img_in_orig.shape) > 3:
        print(f"WARNING: Input image has {len(img_in_orig.shape)} dimensions. Only using first three dimensions.")
        img_in_orig = nib.Nifti1Image(img_in_orig.get_fdata()[:, :, :, 0], img_in_orig.affine)

    img_dtype = img_in_orig.get_data_dtype()
    if img_dtype.fields is not None:
        raise TypeError(f"Invalid dtype {img_dtype}. Expected a simple dtype, not a structured one.")

    # takes ~0.9s for medium image
    img_in = nib.Nifti1Image(img_in_orig.get_fdata(), img_in_orig.affine)  # copy img_in_orig

    if crop is not None:
        if type(crop) is str:
            if crop == "lung" or crop == "pelvis":
                crop_mask_img = combine_masks(crop_path, crop)
            else:
                crop_mask_img = nib.load(crop_path / f"{crop}.nii.gz")
        else:
            crop_mask_img = crop
        img_in, bbox = crop_to_mask(img_in, crop_mask_img, addon=crop_addon, dtype=np.int32,
                                    verbose=verbose)
        if not quiet:
            print(f"  cropping from {crop_mask_img.shape} to {img_in.shape}")

    img_in = as_closest_canonical(img_in)

    # if resample is not None:
    #     if not quiet: print("Resampling...")
    #     st = time.time()
    #     img_in_shape = img_in.shape
    #     img_in_zooms = img_in.header.get_zooms()
    #     img_in_rsp = change_spacing(img_in, resample,
    #                                 order=3, dtype=np.int32,
    #                                 nr_cpus=nr_threads_resampling)  # 4 cpus instead of 1 makes it a bit slower
    #     if verbose:
    #         print(f"  from shape {img_in.shape} to shape {img_in_rsp.shape}")
    #     if not quiet: print(f"  Resampled in {time.time() - st:.2f}s")
    # else:
    #     img_in_rsp = img_in
    #
    # nib.save(img_in_rsp, os.path.join(tmp_dir, "s01_0000.nii.gz"))

    # todo important: change
    nr_voxels_thr = 512 * 512 * 900
    # nr_voxels_thr = 256*256*900
    img_parts = ["s01"]
    # ss = img_in_rsp.shape
    # If image to big then split into 3 parts along z axis. Also make sure that z-axis is at least 200px otherwise
    # splitting along it does not really make sense.
    # do_triple_split = np.prod(ss) > nr_voxels_thr and ss[2] > 200 and multimodel
    # if force_split:
    #     do_triple_split = True
    # if do_triple_split:
    #     if not quiet: print("Splitting into subparts...")
    #     img_parts = ["s01", "s02", "s03"]
    #     third = img_in_rsp.shape[2] // 3
    #     margin = 20  # set margin with fixed values to avoid rounding problem if using percentage of third
    #     img_in_rsp_data = img_in_rsp.get_fdata()
    #     nib.save(nib.Nifti1Image(img_in_rsp_data[:, :, :third + margin], img_in_rsp.affine),
    #              tmp_dir / "s01_0000.nii.gz")
    #     nib.save(nib.Nifti1Image(img_in_rsp_data[:, :, third + 1 - margin:third * 2 + margin], img_in_rsp.affine),
    #              tmp_dir / "s02_0000.nii.gz")
    #     nib.save(nib.Nifti1Image(img_in_rsp_data[:, :, third * 2 + 1 - margin:], img_in_rsp.affine),
    #              tmp_dir / "s03_0000.nii.gz")

    if task_name == "total" and resample is not None and resample[0] < 3.0:
        # overall speedup for 15mm model roughly 11% (GPU) and 100% (CPU)
        # overall speedup for  3mm model roughly  0% (GPU) and  10% (CPU)
        # (dice 0.001 worse on test set -> ok)
        # (for lung_trachea_bronchia somehow a lot lower dice)
        step_size = 0.8
    else:
        step_size = 0.5

    st = time.time()
    if multimodel:  # if running multiple models

        # only compute model parts containing the roi subset
        if roi_subset is not None:
            part_names = []
            new_task_id = []
            for part_name, part_map in class_map_parts.items():
                if any(organ in roi_subset for organ in part_map.values()):
                    # get taskid associated to model part_name
                    map_partname_to_taskid = {v: k for k, v in map_taskid_to_partname.items()}
                    new_task_id.append(map_partname_to_taskid[part_name])
                    part_names.append(part_name)
            task_id = new_task_id
            if verbose:
                print(f"Computing parts: {part_names} based on the provided roi_subset")

        if test == 0:
            class_map_inv = {v: k for k, v in class_map[task_name].items()}
            # (tmp_dir / "parts").mkdir(exist_ok=True)
            seg_combined = {}
            # iterate over subparts of image
            for img_part in img_parts:
                img_shape = nib.load(os.path.join(tmp_dir, f"{img_part}_0000.nii.gz")).shape
                seg_combined[img_part] = np.zeros(img_shape, dtype=np.uint8)
            # Run several tasks and combine results into one segmentation
            for idx, tid in enumerate(task_id):
                if not quiet:
                    print(f"Predicting part {idx + 1} of {len(task_id)} ...")
                with nostdout(verbose):
                    # nnUNet_predict(tmp_dir, tmp_dir, tid, model, folds, trainer, tta,
                    #                nr_threads_resampling, nr_threads_saving)
                    nnUNetv2_predict(tmp_dir, tmp_dir, tid, model, folds, trainer, tta,
                                     nr_threads_resampling, nr_threads_saving,
                                     device=device, quiet=quiet, step_size=step_size)
                # iterate over models (different sets of classes)
                for img_part in img_parts:
                    os.rename(os.path.join(tmp_dir, f"{img_part}.nii.gz"),
                              os.path.join(tmp_dir, f"{img_part}_{tid}.nii.gz"))
                    seg = nib.load(os.path.join(tmp_dir, f"{img_part}_{tid}.nii.gz")).get_fdata()
                    for jdx, class_name in class_map_parts[map_taskid_to_partname[tid]].items():
                        seg_combined[img_part][seg == jdx] = class_map_inv[class_name]
            # iterate over subparts of image
            for img_part in img_parts:
                nib.save(nib.Nifti1Image(seg_combined[img_part], img_in_rsp.affine),
                         os.path.join(tmp_dir, f"{img_part}.nii.gz"))
        # elif test == 1:
        #     print("WARNING: Using reference seg instead of prediction for testing.")
        #     shutil.copy(Path("tests") / "reference_files" / "example_seg.nii.gz", tmp_dir / "s01.nii.gz")
    else:
        if not quiet:
            print("Predicting...")
        if test == 0:
            with nostdout(verbose):
                # nnUNet_predict(tmp_dir, tmp_dir, task_id, model, folds, trainer, tta,
                #                nr_threads_resampling, nr_threads_saving)
                nnUNetv2_predict(tmp_dir, tmp_dir, task_id, model, folds, trainer, tta,
                                 nr_threads_resampling, nr_threads_saving,
                                 device=device, quiet=quiet, step_size=step_size)
        # elif test == 2:
        #     print("WARNING: Using reference seg instead of prediction for testing.")
        #     shutil.copy(Path("tests") / "reference_files" / "example_seg_fast.nii.gz", tmp_dir / f"s01.nii.gz")
        # elif test == 3:
        #     print("WARNING: Using reference seg instead of prediction for testing.")
        #     shutil.copy(Path("tests") / "reference_files" / "example_seg_lung_vessels.nii.gz",
        #                 tmp_dir / "s01.nii.gz")
    if not quiet:
        print(f"  Predicted in {time.time() - st:.2f}s")

    # # Combine image subparts back to one image
    # if do_triple_split:
    #     combined_img = np.zeros(img_in_rsp.shape, dtype=np.uint8)
    #     combined_img[:, :, :third] = nib.load(tmp_dir / "s01.nii.gz").get_fdata()[:, :, :-margin]
    #     combined_img[:, :, third:third * 2] = nib.load(tmp_dir / "s02.nii.gz").get_fdata()[:, :, margin - 1:-margin]
    #     combined_img[:, :, third * 2:] = nib.load(tmp_dir / "s03.nii.gz").get_fdata()[:, :, margin - 1:]
    #     nib.save(nib.Nifti1Image(combined_img, img_in_rsp.affine), tmp_dir / "s01.nii.gz")

    img_pred = nib.load(os.path.join(tmp_dir, "s01.nii.gz"))

    # Currently only relevant for T304 (appendicular bones)
    img_pred = remove_auxiliary_labels(img_pred, task_name)

    # Postprocessing multilabel (run here on lower resolution)
    if task_name == "body":
        img_pred_pp = keep_largest_blob_multilabel(img_pred.get_fdata().astype(np.uint8),
                                                   class_map[task_name], ["body_trunc"], debug=False, quiet=quiet)
        img_pred = nib.Nifti1Image(img_pred_pp, img_pred.affine)

    if task_name == "body":
        vox_vol = np.prod(img_pred.header.get_zooms())
        size_thr_mm3 = 50000 / vox_vol
        img_pred_pp = remove_small_blobs_multilabel(img_pred.get_fdata().astype(np.uint8),
                                                    class_map[task_name], ["body_extremities"],
                                                    interval=[size_thr_mm3, 1e10], debug=False, quiet=quiet)
        img_pred = nib.Nifti1Image(img_pred_pp, img_pred.affine)

    # if preview:
    #     from totalsegmentator.preview import generate_preview
    #     # Generate preview before upsampling so it is faster and still in canonical space
    #     # for better orientation.
    #     if not quiet: print("Generating preview...")
    #     st = time.time()
    #     smoothing = 20
    #     preview_dir = file_out.parent if multilabel_image else file_out
    #     generate_preview(img_in_rsp, preview_dir / f"preview_{task_name}.png", img_pred.get_fdata(), smoothing,
    #                      task_name)
    #     if not quiet: print(f"  Generated in {time.time() - st:.2f}s")

    # Statistics calculated on the 3mm downsampled image are very similar to statistics
    # calculated on the original image. Volume often completely identical. For intensity
    # some more change but still minor.
    #
    # Speed:
    # stats on 1.5mm: 37s
    # stats on 3.0mm: 4s    -> great improvement
    # stats = None
    # if statistics:
    #     if not quiet: print("Calculating statistics fast...")
    #     st = time.time()
    #     if file_out is not None:
    #         stats_dir = file_out.parent if multilabel_image else file_out
    #         stats_dir.mkdir(exist_ok=True)
    #         stats_file = stats_dir / "statistics.json"
    #     else:
    #         stats_file = None
    #     stats = get_basic_statistics(img_pred.get_fdata(), img_in_rsp, stats_file,
    #                                  quiet, task_name, exclude_masks_at_border, roi_subset,
    #                                  metric=stats_aggregation)
    #     if not quiet: print(f"  calculated in {time.time() - st:.2f}s")

    if resample is not None:
        if not quiet:
            print("Resampling...")
        if verbose:
            print(f"  back to original shape: {img_in_shape}")
        # Use force_affine otherwise output affine sometimes slightly off (which then is even increased
        # by undo_canonical)
        img_pred = change_spacing(img_pred, resample, img_in_shape,
                                  order=0, dtype=np.uint8, nr_cpus=nr_threads_resampling,
                                  force_affine=img_in.affine)

    if verbose:
        print("Undoing canonical...")
    img_pred = undo_canonical(img_pred, img_in_orig)

    if crop is not None:
        if verbose:
            print("Undoing cropping...")
        img_pred = undo_crop(img_pred, img_in_orig, bbox)

    check_if_shape_and_affine_identical(img_in_orig, img_pred)

    img_data = img_pred.get_fdata().astype(np.uint8)
    if save_binary:
        img_data = (img_data > 0).astype(np.uint8)

    # Reorder labels if needed
    if v1_order and task_name == "total":
        img_data = reorder_multilabel_like_v1(img_data, class_map["total"], class_map["total_v1"])
        label_map = class_map["total_v1"]
    else:
        label_map = class_map[task_name]

    # Keep only voxel values corresponding to the roi_subset
    if roi_subset is not None:
        label_map = {k: v for k, v in label_map.items() if v in roi_subset}
        img_data *= np.isin(img_data, list(label_map.keys()))

    # Prepare output nifti
    # Copy header to make output header exactly the same as input. But change dtype otherwise it will be
    # float or int and therefore the masks will need a lot more space.
    # (infos on header: https://nipy.org/nibabel/nifti_images.html)
    new_header = img_in_orig.header.copy()
    new_header.set_data_dtype(np.uint8)
    img_out = nib.Nifti1Image(img_data, img_pred.affine, new_header)
    img_out = add_label_map_to_nifti(img_out, label_map)

    if file_out is not None and skip_saving is False:
        if not quiet:
            print("Saving segmentations...")

        # Select subset of classes if required
        selected_classes = class_map[task_name]
        if roi_subset is not None:
            selected_classes = {k: v for k, v in selected_classes.items() if v in roi_subset}

        st = time.time()
        if multilabel_image:
            os.makedirs(os.path.dirname(file_out), exist_ok=True)
            nib.save(img_out, file_out)
        else:
            os.makedirs(file_out, exist_ok=True)
            # save each class as a separate binary image

            if np.prod(img_data.shape) > 512 * 512 * 1000:
                print("Shape of output image is very big. Setting nr_threads_saving=1 to save memory.")
                nr_threads_saving = 1

            # Code for single threaded execution  (runtime:24s)
            nr_threads_saving = 1
            if nr_threads_saving == 1:
                for k, v in selected_classes.items():
                    binary_img = img_data == k
                    output_path = os.path.join(file_out, f"{v}.nii.gz")
                    nib.save(nib.Nifti1Image(binary_img.astype(np.uint8), img_pred.affine, new_header),
                             output_path)
            # else:
            # # Code for multithreaded execution
            # #   Speed with different number of threads:
            # #   1: 46s, 2: 24s, 6: 11s, 10: 8s, 14: 8s
            # nib.save(img_pred, tmp_dir / "s01.nii.gz")
            # _ = p_map(
            #     partial(save_segmentation_nifti, tmp_dir=tmp_dir, file_out=file_out, nora_tag=nora_tag,
            #             header=new_header, task_name=task_name, quiet=quiet),
            #     selected_classes.items(), num_cpus=nr_threads_saving, disable=quiet)
            #
            # # Multithreaded saving with same functions as in nnUNet -> same speed as p_map
            # # pool = Pool(nr_threads_saving)
            # # results = []
            # # for k, v in selected_classes.items():
            # #     results.append(pool.starmap_async(save_segmentation_nifti, ((k, v, tmp_dir, file_out, nora_tag),) ))
            # # _ = [i.get() for i in results]  # this actually starts the execution of the async functions
            # # pool.close()
            # # pool.join()
        if not quiet:
            print(f"  Saved in {time.time() - st:.2f}s")

        # Postprocessing single files
        #    (these are not directly transferable to multilabel)

        # Lung mask does not exist since I use 6mm model. Would have to save lung mask from 6mm seg.
        # if task_name == "lung_vessels":
        #     remove_outside_of_mask(file_out / "lung_vessels.nii.gz", file_out / "lung.nii.gz")

        # if task_name == "heartchambers_test":
        #     remove_outside_of_mask(file_out / "heart_myocardium.nii.gz", file_out / "heart.nii.gz", addon=5)
        #     remove_outside_of_mask(file_out / "heart_atrium_left.nii.gz", file_out / "heart.nii.gz", addon=5)
        #     remove_outside_of_mask(file_out / "heart_ventricle_left.nii.gz", file_out / "heart.nii.gz", addon=5)
        #     remove_outside_of_mask(file_out / "heart_atrium_right.nii.gz", file_out / "heart.nii.gz", addon=5)
        #     remove_outside_of_mask(file_out / "heart_ventricle_right.nii.gz", file_out / "heart.nii.gz", addon=5)
        #     remove_outside_of_mask(file_out / "aorta.nii.gz", file_out / "heart.nii.gz", addon=5)
        #     remove_outside_of_mask(file_out / "pulmonary_artery.nii.gz", file_out / "heart.nii.gz", addon=5)

    return img_out, img_in_orig
